Extending scatter! to work with CUDA sparse arrays#648
Extending scatter! to work with CUDA sparse arrays#648CarloLucibello merged 5 commits intoFluxML:masterfrom
Conversation
|
The integration test with Lux fail because |
|
Seems fine. Is it possible to add a test on CI somehow, perhaps in https://github.com/FluxML/NNlib.jl/blob/master/test/ext_cuda/scatter.jl ? |
|
I added the sparse matrix varieties to the list of array types that are automatically tested @mcabbott Any thoughts on making the implementation less ugly? Should I open an issue on CUDA.jl suggesting making CUSPARSE arrays subtypes of |
|
ops, I shouldn't have merged this @alonsoC1s , scatter tests are failing |
|
@CarloLucibello Turns out I don't need a GPU to test. I'm working on the fix |
Aims to fix #647 by extending the signature of
scatter!to work withAbstractCuSparseArray, a CUDA array type notably excluded by the original method. With the proposed patch, callingscatter!with sparse arrays fromCUDA.CUSPARSEwill correctly call the CUDA-specialized method instead of calling the generic CPU method, which triggered a scalar indexing error. In my testing the existing CUDA kernels work perfectly fine withCuSparseArrayCSC.The proposed implementation, perhaps inelegantly, just expands the types in the signature with
Union{...}. I am open to discussing more beautiful ways of implementing this. Ideally,AbstractCuSparseArraywould be a subtype ofAnyCuArray.PR Checklist